/*
 * Copyright 2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.alibaba.cloud.ai.node;

import com.alibaba.cloud.ai.connector.accessor.Accessor;
import com.alibaba.cloud.ai.connector.bo.DbQueryParameter;
import com.alibaba.cloud.ai.connector.bo.ResultSetBO;
import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.constant.Constant;
import com.alibaba.cloud.ai.enums.StreamResponseType;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.model.execution.ExecutionStep;
import com.alibaba.cloud.ai.util.ChatResponseUtil;
import com.alibaba.cloud.ai.util.StateUtils;
import com.alibaba.cloud.ai.util.StepResultUtils;
import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatResponse;
import reactor.core.publisher.Flux;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.alibaba.cloud.ai.constant.Constant.SQL_EXECUTE_NODE_EXCEPTION_OUTPUT;
import static com.alibaba.cloud.ai.constant.Constant.SQL_EXECUTE_NODE_OUTPUT;

/**
 * SQL execution node that executes SQL queries against the database.
 *
 * <p>
 * This node is responsible for:
 * <ul>
 * <li>Executing the SQL query carried by the current execution step</li>
 * <li>Handling query results and errors</li>
 * <li>Providing streaming feedback to users during execution</li>
 * <li>Managing step-by-step result accumulation</li>
 * </ul>
 *
 * @author zhangshenghang
 */
public class SqlExecuteNode extends AbstractPlanBasedNode {

	private static final Logger logger = LoggerFactory.getLogger(SqlExecuteNode.class);

	private final DbConfig dbConfig;

	private final Accessor dbAccessor;

	/**
	 * Creates the node.
	 * @param dbAccessor accessor used to run the SQL statement
	 * @param dbConfig database configuration handed to the accessor on every call
	 */
	public SqlExecuteNode(Accessor dbAccessor, DbConfig dbConfig) {
		super();
		this.dbAccessor = dbAccessor;
		this.dbConfig = dbConfig;
	}

	/**
	 * Reads the SQL query of the current execution step from the state and executes it.
	 * @param state The overall state containing execution context
	 * @return Map holding a streaming generator, keyed by
	 * {@code SQL_EXECUTE_NODE_OUTPUT} on success or
	 * {@code SQL_EXECUTE_NODE_EXCEPTION_OUTPUT} on failure
	 * @throws Exception if the current step cannot be resolved from the state
	 */
	@Override
	public Map<String, Object> apply(OverAllState state) throws Exception {
		logNodeEntry();

		ExecutionStep executionStep = getCurrentExecutionStep(state);
		Integer currentStep = getCurrentStepNumber(state);

		ExecutionStep.ToolParameters toolParameters = executionStep.getToolParameters();
		String sqlQuery = toolParameters.getSqlQuery();

		logger.info("Executing SQL query: {}", sqlQuery);
		logger.info("Step description: {}", toolParameters.getDescription());

		return executeSqlQuery(state, currentStep, sqlQuery);
	}

	/**
	 * Executes the SQL query against the database and handles the results.
	 *
	 * <p>
	 * This method follows the business-logic-first pattern:
	 * <ol>
	 * <li>Execute the actual SQL query immediately</li>
	 * <li>Process and store the results</li>
	 * <li>Create streaming output for user experience only</li>
	 * </ol>
	 *
	 * <p>
	 * Any exception raised while executing the query or assembling the result is
	 * logged and converted into an error generator instead of being propagated.
	 * @param state The overall state containing execution context
	 * @param currentStep The current step number in the execution plan
	 * @param sqlQuery The SQL query to execute
	 * @return Map containing the generator for streaming output
	 */
	@SuppressWarnings("unchecked")
	private Map<String, Object> executeSqlQuery(OverAllState state, Integer currentStep, String sqlQuery) {
		// Execute business logic first - actual SQL execution
		DbQueryParameter dbQueryParameter = new DbQueryParameter();
		dbQueryParameter.setSql(sqlQuery);

		try {
			// Execute SQL query and get results immediately
			ResultSetBO resultSetBO = dbAccessor.executeSqlAndReturnObject(dbConfig, dbQueryParameter);
			String jsonStr = resultSetBO.toJsonStr();

			// Update step results with the query output
			Map<String, String> existingResults = StateUtils.getObjectValue(state, SQL_EXECUTE_NODE_OUTPUT, Map.class,
					new HashMap<>());
			Map<String, String> updatedResults = StepResultUtils.addStepResult(existingResults, currentStep, jsonStr);

			var rows = resultSetBO.getData();
			logger.info("SQL execution successful, result count: {}", rows != null ? rows.size() : 0);

			// Map.of rejects null values, so a result set without data is published as
			// an empty list rather than being misreported as an SQL failure.
			Object resultList = rows != null ? rows : List.of();

			// Prepare the final result object.
			// Store the SQL query result list so the code execution node can use it.
			Map<String, Object> result = Map.of(SQL_EXECUTE_NODE_OUTPUT, updatedResults,
					SQL_EXECUTE_NODE_EXCEPTION_OUTPUT, "", Constant.SQL_RESULT_LIST_MEMORY, resultList);

			// Create display flux for user experience only
			Flux<ChatResponse> displayFlux = Flux.create(emitter -> {
				emitter.next(ChatResponseUtil.createCustomStatusResponse("开始执行SQL..."));
				emitter.next(ChatResponseUtil.createCustomStatusResponse("执行SQL查询"));
				emitter.next(ChatResponseUtil.createCustomStatusResponse("```" + sqlQuery + "```"));
				emitter.next(ChatResponseUtil.createCustomStatusResponse("执行SQL完成"));
				emitter.complete();
			});

			// Create generator using utility class, returning pre-computed business logic
			// result
			var generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state,
					v -> result, displayFlux, StreamResponseType.EXECUTE_SQL);

			return Map.of(SQL_EXECUTE_NODE_OUTPUT, generator);
		}
		catch (Exception e) {
			String errorMessage = describeFailure(e);
			logger.error("SQL execution failed - SQL: [{}] ", sqlQuery, e);

			// Prepare error result
			Map<String, Object> errorResult = Map.of(SQL_EXECUTE_NODE_EXCEPTION_OUTPUT, errorMessage);

			// Create error display flux
			Flux<ChatResponse> errorDisplayFlux = Flux.create(emitter -> {
				emitter.next(ChatResponseUtil.createCustomStatusResponse("开始执行SQL..."));
				emitter.next(ChatResponseUtil.createCustomStatusResponse("执行SQL查询"));
				emitter.next(ChatResponseUtil.createCustomStatusResponse("SQL执行失败: " + errorMessage));
				emitter.complete();
			});

			// Create error generator using utility class
			var generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state,
					v -> errorResult, errorDisplayFlux, StreamResponseType.EXECUTE_SQL);

			return Map.of(SQL_EXECUTE_NODE_EXCEPTION_OUTPUT, generator);
		}
	}

	/**
	 * Builds a non-blank description of a failure.
	 *
	 * <p>
	 * {@link Throwable#getMessage()} may return {@code null}, which {@code Map.of}
	 * rejects. The success path publishes an empty string under the exception key, so
	 * a blank failure message is also replaced (presumably consumers treat a blank
	 * value as "no error" - confirm against the downstream nodes).
	 * @param e the exception raised while executing the query
	 * @return the exception message, or {@code e.toString()} when the message is
	 * {@code null} or blank
	 */
	private static String describeFailure(Exception e) {
		String message = e.getMessage();
		return (message == null || message.isBlank()) ? e.toString() : message;
	}

}
